-
Notifications
You must be signed in to change notification settings - Fork 493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for QAT + LoRA #1931
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1931
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 95961d4 with merge base abdb5a4 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
e20e891
to
d09c71f
Compare
d09c71f
to
1a48a20
Compare
Hey @andrewor14, I was hacking around with LoRA/QLoRA + INT8 mixed-precision and came across this PR of yours. I realized what we are trying to achieve is quite similar. Components of
Since the base weight is direct children of
I have a POC here main...gau-nernst:qlora (you can focus on |
Hi @gau-nernst, yeah I agree we can make the base weight more flexible, then we won't need to create a new class every time we need to extend lora functionality. cc @ebsmothers to see your thoughts on extending For the separate recipe, I discussed this with @ebsmothers recently and I think it's torchtune's recipe organization philosophy to keep them separate, so QAT functionality won't complicate the original lora recipe. |
1a48a20
to
57847fc
Compare
4c1b14f
to
9faac01
Compare
b34da3b
to
7a600cc
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1931 +/- ##
===========================================
- Coverage 67.29% 24.40% -42.89%
===========================================
Files 318 325 +7
Lines 17646 18498 +852
===========================================
- Hits 11874 4515 -7359
- Misses 5772 13983 +8211 ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
@ebsmothers Any comments? Does this look good to you? |
Hey @andrewor14 sorry for the delay and thanks for your patience here. We are doing planning this week so my available bandwidth for reviewing this has taken a hit. I promise to get to it by Friday at the latest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @andrewor14 for the PR, and for your patience in the review process! I left a number of comments but no major concerns from my side.
Regarding the two discussion points raised by @gau-nernst previously:
(1) I am open to discussing whether we should change how we expose self.weight
in LoRALinear
from just nn.Parameter
to nn.Linear
. I agree that the latter would be more module-swap friendly, but (for better or for worse) we also do not really design things to have module-swap-based methods as a first-class citizen. Also I think it is not for free, as it means that the key names of the base linear weight will now have an extra module name in between for LoRALinear
(e.g. q_proj.base_weight.weight
instead of q_proj.weight
). On the surface this is quite trivial, but from a checkpointing perspective it's really nice that we currently have an exact match between nn.Linear and LoRALinear and hence can load weights from one directly into the other (with strict=False
of course). So making this simple change would actually have a pretty big blast radius across all our various checkpointing logic. Lmk if these points make sense, happy to discuss further with you.
(2) I think Andrew's separation of QAT full and QAT LoRA into separate recipes makes sense and aligns with what we do for our other recipes. If you are (as I am) a bit concerned about the proliferation of many recipes with similar functionality, I have two comments: (1) we will be making a dedicated effort to start simplifying and consolidating the recipe files somewhat so that they look more like their earlier versions. And (2) I think we now have enough recipes that we can actually consider e.g. recipes/qat
, recipes/knowledge_distillation
, etc. subdirectories. Neither of these will happen tomorrow, but I would like both to happen eventually.
- Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of | ||
training. Currently we checkpoint both the adapter weights (trainable params only) and the | ||
complete merged weights (adapter weights added back to the base model). For more details |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm I think we may need to update our lora_finetune_distributed.py docstring as well.. really we should make it clear that this behavior can be disabled if save_adapter_weights_only=True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good, let's fix these recipes in a separate PR
self._sampler.set_epoch(curr_epoch) | ||
|
||
pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) | ||
for idx, batch in enumerate(self._dataloader): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So no option to wait N steps before enabling fake quant in this recipe? Any particular reason for that? (To clarify I'm not saying we should add it, mainly just curious)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not super straightforward to add it right now because this recipe uses the new general FakeQuantizedLinear
, as opposed to the specific Int8DynActInt4WeightLinear
class used by the qat_distributed
recipe. I think we can add it separately
@ebsmothers Regarding
In my proof-of-concept above, I handle this by adding the following hooks def load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if isinstance(module, LoRALinear):
state_dict[f"{prefix}base.weight"] = state_dict.pop(f"{prefix}weight")
self.register_load_state_dict_pre_hook(load_state_dict_pre_hook)
def state_dict_post_hook(module, state_dict, prefix, local_metadata):
if isinstance(module, LoRALinear):
state_dict[f"{prefix}weight"] = state_dict.pop(f"{prefix}base.weight")
self.register_state_dict_post_hook(state_dict_post_hook) From my testing it seems sufficient, though I might not cover all edge cases (FSDP2?) We can discuss more in a separate issue/PR if you are open to it, so as not to hijack this PR about QAT + LoRA 😄. The main benefit is ease of injecting custom logic, such as QAT for this PR, INT8 matmul for #1552, or even FP8 matmul in the future. You probably know better than me what are the potential issues, but I think we can try to see if those can be handled nicely. |
@gau-nernst personally I have a bit of an aversion to state dict hooks as @pbontrager can attest 😅. Mainly I find that they make code really hard to debug. Correct usage of modules having state dict hooks generally requires that a module has its state dict called exactly once and submodules are not accessed or modified in any other way. And if either of these constraints are not satisfied the user will get a very non-obvious error about some missing attribute and it won't be at all clear where to go to fix it. But I agree with your point about consolidating the discussion elsewhere (sounds like this PR wouldn't benefit as much from modifying |
131542a
to
618fdce
Compare
724fb11
to
0256e97
Compare
@ebsmothers any other comments? |
c0e9778
to
1181e39
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK a couple more small comments but after that I think this should be good to go. A couple other requests before landing:
- Can you make sure this works with all our usual features (e.g. activation checkpointing, activation offloading)? I already ran with compile myself so no need to worry about that one
- You should also add it to the recipes table in our readme! That way people will know to try it out
@@ -232,3 +235,12 @@ def test_quantized_state_dict(self, dtype): | |||
assert torch.allclose( | |||
lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data | |||
) | |||
|
|||
@pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+") | |||
def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this!
9f154fe
to
c3c0d4a
Compare
Sounds good. I think I addressed all of the comments and also tested it with the features you mentioned. Please take another look, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two more really minor comments. After that it's good to merge. Thanks so much for adding this!
**Summary:** This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe `qat_lora_finetune_distributed` mirrors the existing `lora_finetune_distributed` recipe, which performs only LoRA, and is analogous to the existing `qat_distributed` recipe, which performs only QAT. Helpful code review commands: ``` diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml diff --color recipes/configs/llama3_2/1B_lora.yaml recipes/configs/llama3_2/1B_qat_lora.yaml diff --color recipes/configs/llama3_2/3B_lora.yaml recipes/configs/llama3_2/3B_qat_lora.yaml ``` For more context on QAT, please visit pytorch#980 and https://pytorch.org/blog/quantization-aware-training/. **Test Plan** Unit tests: ``` pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py ``` Manual tests: ``` export CUDA_VISIBLE_DEVICES=4,5,6,7 export NCCL_SHM_DISABLE=0 LOG_DIR=/home/andrewor/local/logs/tune/qat_lora tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \ batch_size=4 \ quantizer.groupsize=32 \ checkpointer.output_dir="$LOG_DIR" \ metric_logger.output_dir="${LOG_DIR}/metrics" tune run quantize --config quantization \ model._component_=torchtune.models.llama3.llama3_8b \ checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \ checkpointer.checkpoint_dir="$LOG_DIR" \ checkpointer.output_dir="$LOG_DIR" \ checkpointer.checkpoint_files=["meta_model_0.pt"] \ checkpointer.model_type=LLAMA3 \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 tune run eleuther_eval --config eleuther_evaluation \ batch_size=1 \ model._component_=torchtune.models.llama3.llama3_8b \ checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir="$LOG_DIR" \ checkpointer.output_dir="$LOG_DIR" \ checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \ checkpointer.model_type=LLAMA3 \ tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \ tasks=[wikitext] \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 ``` Results: ``` | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.6284|± | N/A| | | |none |None |byte_perplexity|↓ | 1.5458|± | N/A| | | |none |None |word_perplexity|↓ |10.2694|± | N/A| | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.6245|± | N/A| | | |none |None |byte_perplexity|↓ | 1.5416|± | N/A| | | |none |None |word_perplexity|↓ |10.1208|± | N/A| ```
c3c0d4a
to
95961d4
Compare
Summary:
This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe
qat_lora_finetune_distributed
mirrors the existinglora_finetune_distributed
recipe, which performs only LoRA, and is analogous to the existingqat_distributed
recipe, which performs only QAT.Helpful code review commands:
For more context on QAT, please visit #980 and https://pytorch.org/blog/quantization-aware-training/.
Test Plan
Unit tests:
Manual tests:
Results: